{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "69637b5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: kfp==1.4.0 in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 1)) (1.4.0)\n",
      "Requirement already satisfied: numpy in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 2)) (1.21.2)\n",
      "Requirement already satisfied: keras in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 3)) (2.6.0)\n",
      "Requirement already satisfied: tqdm in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 4)) (4.62.3)\n",
      "Requirement already satisfied: config in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 5)) (0.5.1)\n",
      "Requirement already satisfied: sklearn in ./.local/lib/python3.8/site-packages (from -r requirement.txt (line 6)) (0.0)\n",
      "Requirement already satisfied: kubernetes<12.0.0,>=8.0.0 in ./.local/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (11.0.0)\n",
      "Requirement already satisfied: Deprecated in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (1.2.12)\n",
      "Requirement already satisfied: kfp-pipeline-spec<0.2.0,>=0.1.0 in ./.local/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.1.9)\n",
      "Requirement already satisfied: requests-toolbelt>=0.8.0 in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.9.1)\n",
      "Requirement already satisfied: kfp-server-api<2.0.0,>=1.1.2 in ./.local/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (1.6.0)\n",
      "Requirement already satisfied: docstring-parser>=0.7.3 in ./.local/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.10)\n",
      "Requirement already satisfied: PyYAML>=5.3 in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (5.4.1)\n",
      "Requirement already satisfied: click in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (7.1.2)\n",
      "Requirement already satisfied: jsonschema>=3.0.1 in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (3.2.0)\n",
      "Requirement already satisfied: google-cloud-storage>=1.13.0 in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (1.37.1)\n",
      "Requirement already satisfied: cloudpickle in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (1.6.0)\n",
      "Requirement already satisfied: fire>=0.3.1 in ./.local/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.4.0)\n",
      "Requirement already satisfied: tabulate in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.8.9)\n",
      "Requirement already satisfied: strip-hints in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (0.1.9)\n",
      "Requirement already satisfied: google-auth>=1.6.1 in /opt/conda/lib/python3.8/site-packages (from kfp==1.4.0->-r requirement.txt (line 1)) (1.28.1)\n",
      "Requirement already satisfied: termcolor in ./.local/lib/python3.8/site-packages (from fire>=0.3.1->kfp==1.4.0->-r requirement.txt (line 1)) (1.1.0)\n",
      "Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from fire>=0.3.1->kfp==1.4.0->-r requirement.txt (line 1)) (1.15.0)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from google-auth>=1.6.1->kfp==1.4.0->-r requirement.txt (line 1)) (0.2.8)\n",
      "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.8/site-packages (from google-auth>=1.6.1->kfp==1.4.0->-r requirement.txt (line 1)) (4.7.2)\n",
      "Requirement already satisfied: cachetools<5.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from google-auth>=1.6.1->kfp==1.4.0->-r requirement.txt (line 1)) (4.2.1)\n",
      "Requirement already satisfied: setuptools>=40.3.0 in /opt/conda/lib/python3.8/site-packages (from google-auth>=1.6.1->kfp==1.4.0->-r requirement.txt (line 1)) (49.6.0.post20210108)\n",
      "Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in ./.local/lib/python3.8/site-packages (from google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2.26.0)\n",
      "Requirement already satisfied: google-cloud-core<2.0dev,>=1.4.1 in /opt/conda/lib/python3.8/site-packages (from google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.6.0)\n",
      "Requirement already satisfied: google-resumable-media<2.0dev,>=1.2.0 in /opt/conda/lib/python3.8/site-packages (from google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.2.0)\n",
      "Requirement already satisfied: google-api-core<2.0.0dev,>=1.21.0 in /opt/conda/lib/python3.8/site-packages (from google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.26.3)\n",
      "Requirement already satisfied: packaging>=14.3 in /opt/conda/lib/python3.8/site-packages (from google-api-core<2.0.0dev,>=1.21.0->google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (20.9)\n",
      "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.6.0 in /opt/conda/lib/python3.8/site-packages (from google-api-core<2.0.0dev,>=1.21.0->google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.53.0)\n",
      "Requirement already satisfied: pytz in /opt/conda/lib/python3.8/site-packages (from google-api-core<2.0.0dev,>=1.21.0->google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2021.1)\n",
      "Requirement already satisfied: protobuf>=3.12.0 in /opt/conda/lib/python3.8/site-packages (from google-api-core<2.0.0dev,>=1.21.0->google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (3.15.7)\n",
      "Requirement already satisfied: google-crc32c<2.0dev,>=1.0 in /opt/conda/lib/python3.8/site-packages (from google-resumable-media<2.0dev,>=1.2.0->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.1.2)\n",
      "Requirement already satisfied: cffi>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from google-crc32c<2.0dev,>=1.0->google-resumable-media<2.0dev,>=1.2.0->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.14.5)\n",
      "Requirement already satisfied: pycparser in /opt/conda/lib/python3.8/site-packages (from cffi>=1.0.0->google-crc32c<2.0dev,>=1.0->google-resumable-media<2.0dev,>=1.2.0->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2.20)\n",
      "Requirement already satisfied: pyrsistent>=0.14.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=3.0.1->kfp==1.4.0->-r requirement.txt (line 1)) (0.17.3)\n",
      "Requirement already satisfied: attrs>=17.4.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema>=3.0.1->kfp==1.4.0->-r requirement.txt (line 1)) (20.3.0)\n",
      "Requirement already satisfied: python-dateutil in /opt/conda/lib/python3.8/site-packages (from kfp-server-api<2.0.0,>=1.1.2->kfp==1.4.0->-r requirement.txt (line 1)) (2.8.1)\n",
      "Requirement already satisfied: urllib3>=1.15 in /opt/conda/lib/python3.8/site-packages (from kfp-server-api<2.0.0,>=1.1.2->kfp==1.4.0->-r requirement.txt (line 1)) (1.26.4)\n",
      "Requirement already satisfied: certifi in /opt/conda/lib/python3.8/site-packages (from kfp-server-api<2.0.0,>=1.1.2->kfp==1.4.0->-r requirement.txt (line 1)) (2020.12.5)\n",
      "Requirement already satisfied: requests-oauthlib in /opt/conda/lib/python3.8/site-packages (from kubernetes<12.0.0,>=8.0.0->kfp==1.4.0->-r requirement.txt (line 1)) (1.3.0)\n",
      "Requirement already satisfied: websocket-client!=0.40.0,!=0.41.*,!=0.42.*,>=0.32.0 in /opt/conda/lib/python3.8/site-packages (from kubernetes<12.0.0,>=8.0.0->kfp==1.4.0->-r requirement.txt (line 1)) (0.58.0)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.8/site-packages (from packaging>=14.3->google-api-core<2.0.0dev,>=1.21.0->google-cloud-core<2.0dev,>=1.4.1->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2.4.7)\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth>=1.6.1->kfp==1.4.0->-r requirement.txt (line 1)) (0.4.8)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests<3.0.0dev,>=2.18.0->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2.10)\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0 in ./.local/lib/python3.8/site-packages (from requests<3.0.0dev,>=2.18.0->google-cloud-storage>=1.13.0->kfp==1.4.0->-r requirement.txt (line 1)) (2.0.4)\n",
      "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (from sklearn->-r requirement.txt (line 6)) (0.24.1)\n",
      "Requirement already satisfied: wrapt<2,>=1.10 in /opt/conda/lib/python3.8/site-packages (from Deprecated->kfp==1.4.0->-r requirement.txt (line 1)) (1.12.1)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.8/site-packages (from requests-oauthlib->kubernetes<12.0.0,>=8.0.0->kfp==1.4.0->-r requirement.txt (line 1)) (3.1.0)\n",
      "Requirement already satisfied: scipy>=0.19.1 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->sklearn->-r requirement.txt (line 6)) (1.6.2)\n",
      "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->sklearn->-r requirement.txt (line 6)) (1.0.1)\n",
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->sklearn->-r requirement.txt (line 6)) (2.1.0)\n",
      "Requirement already satisfied: wheel in /opt/conda/lib/python3.8/site-packages (from strip-hints->kfp==1.4.0->-r requirement.txt (line 1)) (0.36.2)\n"
     ]
    }
   ],
   "source": [
    "with open(\"requirement.txt\", \"w\") as f:\n",
    "    f.write(\"kfp==1.4.0\\n\")\n",
    "    f.write(\"numpy\\n\")\n",
    "    f.write(\"keras\\n\")\n",
    "    f.write(\"tqdm\\n\")\n",
    "    f.write(\"config\\n\")\n",
    "    f.write(\"sklearn\\n\")\n",
    "!pip install -r requirement.txt  --upgrade --user\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d84bac6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import NamedTuple\n",
    "import numpy\n",
    "def load_data(log_folder:str)->NamedTuple('Outputs', [('start_time_string',str)]):\n",
    "    import numpy as np\n",
    "    import time\n",
    "    import sys\n",
    "    print(\"import done...\")\n",
    "    start = time.time()\n",
    "    data= np.load(\"triplet-data.npz\")\n",
    "    sys.path.append(\"./\")\n",
    "    \n",
    "#    from config import img_size, channel, faces_data_dir, FREEZE_LAYERS, classify, facenet_weight_path   \n",
    "#    from inception_resnet_v1 import InceptionResNetV1\n",
    "#    from utils import scatter \n",
    "    X_train, X_test = data['arr_0'], data['arr_1']\n",
    "    print(X_train.shape, X_test.shape)\n",
    "    print(\"Saving data...\")\n",
    "    #print(X_train)\n",
    "    #print(X_test)\n",
    "    np.savez_compressed('/persist-log/triplet-data.npz', X_train, X_test)\n",
    "    print('Save complete ...')\n",
    "    start_time_string=str(start) #type is string\n",
    "    return [start_time_string]\n",
    "\n",
    "\n",
    "def distributed_training_worker1(start_time_string:str)->NamedTuple('Outputs',[('model_path',str)]):\n",
    "    import numpy as np\n",
    "    import sys\n",
    "    import time\n",
    "    import tensorflow as tf\n",
    "    import json\n",
    "    import os\n",
    "    sys.path.append(\"./\")\n",
    "    sys.path.append(\"/persist-log\")\n",
    "    from config import img_size, channel, faces_data_dir, FREEZE_LAYERS, classify, facenet_weight_path\n",
    "    from inception_resnet_v1 import InceptionResNetV1\n",
    "    from itertools import permutations\n",
    "    from tqdm import tqdm\n",
    "    from tensorflow.keras import backend as K\n",
    "    from sklearn.manifold import TSNE\n",
    "    \n",
    "    #load data from pvc in the container\n",
    "    data = np.load('/persist-log/triplet-data.npz')\n",
    "    X_train, X_test = data['arr_0'], data['arr_1']\n",
    "\n",
    "    def training_model(in_shape,freeze_layers,weights_path):\n",
    "\n",
    "        def create_base_network(in_dims,freeze_layers,weights_path):\n",
    "            model = InceptionResNetV1(input_shape=in_dims, weights_path=weights_path)\n",
    "            print('layer length: ', len(model.layers))\n",
    "            for layer in model.layers[:freeze_layers]:\n",
    "                layer.trainable = False\n",
    "            for layer in model.layers[freeze_layers:]:\n",
    "                layer.trainable = True\n",
    "            return model\n",
    "        \n",
    "        def triplet_loss(y_true,y_pred,alpha=0.4):\n",
    "            total_lenght = y_pred.shape.as_list()[-1]\n",
    "            anchor = y_pred[:, 0:int(total_lenght * 1 / 3)]\n",
    "            positive = y_pred[:, int(total_lenght * 1 / 3):int(total_lenght * 2 / 3)]\n",
    "            negative = y_pred[:, int(total_lenght * 2 / 3):int(total_lenght * 3 / 3)]\n",
    "            # distance between the anchor and the positive\n",
    "            pos_dist = K.sum(K.square(anchor - positive), axis=1)\n",
    "            # distance between the anchor and the negative\n",
    "            neg_dist = K.sum(K.square(anchor - negative), axis=1)\n",
    "            # compute loss\n",
    "            basic_loss = pos_dist - neg_dist + alpha\n",
    "            loss = K.maximum(basic_loss, 0.0)\n",
    "            return loss\n",
    "        # define triplet input layers\n",
    "        anchor_input = tf.keras.layers.Input(in_shape, name='anchor_input')\n",
    "        positive_input = tf.keras.layers.Input(in_shape, name='positive_input')\n",
    "        negative_input = tf.keras.layers.Input(in_shape, name='negative_input')\n",
    "        Shared_DNN = create_base_network(in_shape, freeze_layers, weights_path)\n",
    "        # Shared_DNN.summary()\n",
    "        # encoded inputs\n",
    "        encoded_anchor = Shared_DNN(anchor_input)\n",
    "        encoded_positive = Shared_DNN(positive_input)\n",
    "        encoded_negative = Shared_DNN(negative_input)\n",
    "        # output\n",
    "        merged_vector = tf.keras.layers.concatenate([encoded_anchor, encoded_positive, encoded_negative],axis=-1,name='merged_layer')\n",
    "        model = tf.keras.Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_vector)\n",
    "        model.compile(\n",
    "            optimizer=adam_optim,\n",
    "            loss=triplet_loss,\n",
    "        )\n",
    "        return model\n",
    "    \n",
    "    \n",
    "    os.environ['TF_CONFIG'] = json.dumps({'cluster': {'worker': [\"pipeline-worker-1:3000\",\"pipeline-worker-2:3000\",\"pipeline-worker-3:3000\"]},'task': {'type': 'worker', 'index': 0}})\n",
    "    #os.environ['TF_CONFIG'] = json.dumps({'cluster': {'worker': [\"pipeline-worker-1:3000\"]},'task': {'type': 'worker', 'index': 0}})\n",
    "\n",
    "\n",
    "    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(\n",
    "        tf.distribute.experimental.CollectiveCommunication.RING)\n",
    "    NUM_WORKERS = strategy.num_replicas_in_sync\n",
    "    print('=================\\r\\nWorkers: ' + str(NUM_WORKERS) + '\\r\\n=================\\r\\n')\n",
    "    learn_rate = 0.0001 + NUM_WORKERS * 0.00006\n",
    "    adam_optim = tf.keras.optimizers.Adam(lr=learn_rate)\n",
    "    batch_size = 32* NUM_WORKERS\n",
    "    model_path='/persist-log/weight_tfdl.h5'\n",
    "    print(model_path)\n",
    "    callbacks = [tf.keras.callbacks.ModelCheckpoint(model_path, save_weights_only=True, verbose=1)]\n",
    "    #X_train=np.array(X_train)\n",
    "    #print(type(X_train))\n",
    "    with strategy.scope():\n",
    "        Anchor = X_train[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        Positive = X_train[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        Negative = X_train[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        Y_dummy = np.empty(Anchor.shape[0])\n",
    "        model = training_model((img_size, img_size, channel), FREEZE_LAYERS, facenet_weight_path)\n",
    "        \n",
    "    model.fit(x=[Anchor, Positive, Negative],\n",
    "        y=Y_dummy,\n",
    "        # Anchor_test = X_test[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Positive_test = X_test[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Negative_test = X_test[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Y_dummy = np.empty(Anchor.shape[0])\n",
    "        # Y_dummy2 = np.empty((Anchor_test.shape[0], 1))\n",
    "        # validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2),\n",
    "        # validation_split=0.2,\n",
    "        batch_size=batch_size,  # old setting: 32\n",
    "        # steps_per_epoch=(X_train.shape[0] // batch_size) + 1,\n",
    "        epochs=10,\n",
    "        callbacks=callbacks\n",
    "        )  \n",
    "    end = time.time()\n",
    "    start_time_float=float(start_time_string)\n",
    "    print('execution time = ', ((end - start_time_float)/60))\n",
    "    return [model_path]\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "97bad8e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def distributed_training_worker2(start_time_string:str)->NamedTuple('Outputs',[('model_path_work2',str)]):\n",
    "    import numpy as np\n",
    "    import sys\n",
    "    import time\n",
    "    import tensorflow as tf\n",
    "    import json\n",
    "    import os\n",
    "    sys.path.append(\"./\")\n",
    "    sys.path.append(\"/persist-log\")\n",
    "    from config import img_size, channel, faces_data_dir, FREEZE_LAYERS, classify, facenet_weight_path\n",
    "    from inception_resnet_v1 import InceptionResNetV1\n",
    "    from itertools import permutations\n",
    "    from tqdm import tqdm\n",
    "    from tensorflow.keras import backend as K\n",
    "    from sklearn.manifold import TSNE\n",
    "    \n",
    "    #load data from pvc in the container\n",
    "    data = np.load('/persist-log/triplet-data.npz')\n",
    "    X_train, X_test = data['arr_0'], data['arr_1']\n",
    "\n",
    "    def training_model(in_shape,freeze_layers,weights_path):\n",
    "\n",
    "        def create_base_network(in_dims,freeze_layers,weights_path):\n",
    "            model = InceptionResNetV1(input_shape=in_dims, weights_path=weights_path)\n",
    "            print('layer length: ', len(model.layers))\n",
    "            for layer in model.layers[:freeze_layers]:\n",
    "                layer.trainable = False\n",
    "            for layer in model.layers[freeze_layers:]:\n",
    "                layer.trainable = True\n",
    "            return model\n",
    "        \n",
    "        def triplet_loss(y_true,y_pred,alpha=0.4):\n",
    "            total_lenght = y_pred.shape.as_list()[-1]\n",
    "            anchor = y_pred[:, 0:int(total_lenght * 1 / 3)]\n",
    "            positive = y_pred[:, int(total_lenght * 1 / 3):int(total_lenght * 2 / 3)]\n",
    "            negative = y_pred[:, int(total_lenght * 2 / 3):int(total_lenght * 3 / 3)]\n",
    "            # distance between the anchor and the positive\n",
    "            pos_dist = K.sum(K.square(anchor - positive), axis=1)\n",
    "            # distance between the anchor and the negative\n",
    "            neg_dist = K.sum(K.square(anchor - negative), axis=1)\n",
    "            # compute loss\n",
    "            basic_loss = pos_dist - neg_dist + alpha\n",
    "            loss = K.maximum(basic_loss, 0.0)\n",
    "            return loss\n",
    "        # define triplet input layers\n",
    "        anchor_input = tf.keras.layers.Input(in_shape, name='anchor_input')\n",
    "        positive_input = tf.keras.layers.Input(in_shape, name='positive_input')\n",
    "        negative_input = tf.keras.layers.Input(in_shape, name='negative_input')\n",
    "        Shared_DNN = create_base_network(in_shape, freeze_layers, weights_path)\n",
    "        # Shared_DNN.summary()\n",
    "        # encoded inputs\n",
    "        encoded_anchor = Shared_DNN(anchor_input)\n",
    "        encoded_positive = Shared_DNN(positive_input)\n",
    "        encoded_negative = Shared_DNN(negative_input)\n",
    "        # output\n",
    "        merged_vector = tf.keras.layers.concatenate([encoded_anchor, encoded_positive, encoded_negative],axis=-1,name='merged_layer')\n",
    "        model = tf.keras.Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_vector)\n",
    "        model.compile(\n",
    "            optimizer=adam_optim,\n",
    "            loss=triplet_loss,\n",
    "        )\n",
    "        return model\n",
    "    \n",
    "    \n",
    "    os.environ['TF_CONFIG'] = json.dumps({'cluster': {'worker': [\"pipeline-worker-1:3000\",\"pipeline-worker-2:3000\",\"pipeline-worker-3:3000\"]},'task': {'type': 'worker', 'index': 1}})\n",
    "\n",
    "\n",
    "    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(\n",
    "        tf.distribute.experimental.CollectiveCommunication.RING)\n",
    "    NUM_WORKERS = strategy.num_replicas_in_sync\n",
    "    print('=================\\r\\nWorkers: ' + str(NUM_WORKERS) + '\\r\\n=================\\r\\n')\n",
    "    learn_rate = 0.0001 + NUM_WORKERS * 0.00006\n",
    "    adam_optim = tf.keras.optimizers.Adam(lr=learn_rate)\n",
    "    batch_size = 32* NUM_WORKERS\n",
    "    model_path_work2='/persist-log/weight_tfdl.h5'\n",
    "\n",
    "    callbacks = [tf.keras.callbacks.ModelCheckpoint(model_path_work2, save_weights_only=True, verbose=1)]\n",
    "    #X_train=np.array(X_train)\n",
    "    #print(type(X_train))\n",
    "    with strategy.scope():\n",
    "        Anchor = X_train[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        Positive = X_train[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        Negative = X_train[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        Y_dummy = np.empty(Anchor.shape[0])\n",
    "        model = training_model((img_size, img_size, channel), FREEZE_LAYERS, facenet_weight_path)\n",
    "        \n",
    "    model.fit(x=[Anchor, Positive, Negative],\n",
    "        y=Y_dummy,\n",
    "        # Anchor_test = X_test[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Positive_test = X_test[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Negative_test = X_test[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Y_dummy = np.empty(Anchor.shape[0])\n",
    "        # Y_dummy2 = np.empty((Anchor_test.shape[0], 1))\n",
    "        # validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2),\n",
    "        # validation_split=0.2,\n",
    "        batch_size=batch_size,  # old setting: 32\n",
    "        # steps_per_epoch=(X_train.shape[0] // batch_size) + 1,\n",
    "        epochs=10,\n",
    "        callbacks=callbacks\n",
    "        )  \n",
    "    end = time.time()\n",
    "    start_time_float=float(start_time_string)\n",
    "    print('execution time = ', ((end - start_time_float)/60))\n",
    "    return [model_path_work2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "cd6214a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def distributed_training_worker3(start_time_string:str)->NamedTuple('Outputs',[('model_path_work3',str)]):\n",
    "    import numpy as np\n",
    "    import sys\n",
    "    import time\n",
    "    import tensorflow as tf\n",
    "    import json\n",
    "    import os\n",
    "    sys.path.append(\"./\")\n",
    "    sys.path.append(\"/persist-log\")\n",
    "    from config import img_size, channel, faces_data_dir, FREEZE_LAYERS, classify, facenet_weight_path\n",
    "    from inception_resnet_v1 import InceptionResNetV1\n",
    "    from itertools import permutations\n",
    "    from tqdm import tqdm\n",
    "    from tensorflow.keras import backend as K\n",
    "    from sklearn.manifold import TSNE\n",
    "    \n",
    "    #load data from pvc in the container\n",
    "    data = np.load('/persist-log/triplet-data.npz')\n",
    "    X_train, X_test = data['arr_0'], data['arr_1']\n",
    "\n",
    "    def training_model(in_shape,freeze_layers,weights_path):\n",
    "\n",
    "        def create_base_network(in_dims,freeze_layers,weights_path):\n",
    "            model = InceptionResNetV1(input_shape=in_dims, weights_path=weights_path)\n",
    "            print('layer length: ', len(model.layers))\n",
    "            for layer in model.layers[:freeze_layers]:\n",
    "                layer.trainable = False\n",
    "            for layer in model.layers[freeze_layers:]:\n",
    "                layer.trainable = True\n",
    "            return model\n",
    "        \n",
    "        def triplet_loss(y_true,y_pred,alpha=0.4):\n",
    "            total_lenght = y_pred.shape.as_list()[-1]\n",
    "            anchor = y_pred[:, 0:int(total_lenght * 1 / 3)]\n",
    "            positive = y_pred[:, int(total_lenght * 1 / 3):int(total_lenght * 2 / 3)]\n",
    "            negative = y_pred[:, int(total_lenght * 2 / 3):int(total_lenght * 3 / 3)]\n",
    "            # distance between the anchor and the positive\n",
    "            pos_dist = K.sum(K.square(anchor - positive), axis=1)\n",
    "            # distance between the anchor and the negative\n",
    "            neg_dist = K.sum(K.square(anchor - negative), axis=1)\n",
    "            # compute loss\n",
    "            basic_loss = pos_dist - neg_dist + alpha\n",
    "            loss = K.maximum(basic_loss, 0.0)\n",
    "            return loss\n",
    "        # define triplet input layers\n",
    "        anchor_input = tf.keras.layers.Input(in_shape, name='anchor_input')\n",
    "        positive_input = tf.keras.layers.Input(in_shape, name='positive_input')\n",
    "        negative_input = tf.keras.layers.Input(in_shape, name='negative_input')\n",
    "        Shared_DNN = create_base_network(in_shape, freeze_layers, weights_path)\n",
    "        # Shared_DNN.summary()\n",
    "        # encoded inputs\n",
    "        encoded_anchor = Shared_DNN(anchor_input)\n",
    "        encoded_positive = Shared_DNN(positive_input)\n",
    "        encoded_negative = Shared_DNN(negative_input)\n",
    "        # output\n",
    "        merged_vector = tf.keras.layers.concatenate([encoded_anchor, encoded_positive, encoded_negative],axis=-1,name='merged_layer')\n",
    "        model = tf.keras.Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_vector)\n",
    "        model.compile(\n",
    "            optimizer=adam_optim,\n",
    "            loss=triplet_loss,\n",
    "        )\n",
    "        return model\n",
    "    \n",
    "    \n",
    "    os.environ['TF_CONFIG'] = json.dumps({'cluster': {'worker': [\"pipeline-worker-1:3000\",\"pipeline-worker-2:3000\",\"pipeline-worker-3:3000\"]},'task': {'type': 'worker', 'index': 2}})\n",
    "\n",
    "\n",
    "    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(\n",
    "        tf.distribute.experimental.CollectiveCommunication.RING)\n",
    "    NUM_WORKERS = strategy.num_replicas_in_sync\n",
    "    print('=================\\r\\nWorkers: ' + str(NUM_WORKERS) + '\\r\\n=================\\r\\n')\n",
    "    learn_rate = 0.0001 + NUM_WORKERS * 0.00006\n",
    "    adam_optim = tf.keras.optimizers.Adam(lr=learn_rate)\n",
    "    batch_size = 32* NUM_WORKERS\n",
    "    model_path_work3='/persist-log/weight_tfdl.h5'\n",
    "    callbacks = [tf.keras.callbacks.ModelCheckpoint(model_path_work3, save_weights_only=True, verbose=1)]\n",
    "    #X_train=np.array(X_train)\n",
    "    #print(type(X_train))\n",
    "    with strategy.scope():\n",
    "        Anchor = X_train[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        Positive = X_train[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        Negative = X_train[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        Y_dummy = np.empty(Anchor.shape[0])\n",
    "        model = training_model((img_size, img_size, channel), FREEZE_LAYERS, facenet_weight_path)\n",
    "        \n",
    "    model.fit(x=[Anchor, Positive, Negative],\n",
    "        y=Y_dummy,\n",
    "        # Anchor_test = X_test[:, 0, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Positive_test = X_test[:, 1, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Negative_test = X_test[:, 2, :].reshape(-1, img_size, img_size, channel)\n",
    "        # Y_dummy = np.empty(Anchor.shape[0])\n",
    "        # Y_dummy2 = np.empty((Anchor_test.shape[0], 1))\n",
    "        # validation_data=([Anchor_test,Positive_test,Negative_test],Y_dummy2),\n",
    "        # validation_split=0.2,\n",
    "        batch_size=batch_size,  # old setting: 32\n",
    "        # steps_per_epoch=(X_train.shape[0] // batch_size) + 1,\n",
    "        epochs=10,\n",
    "        callbacks=callbacks\n",
    "        )  \n",
    "    end = time.time()\n",
    "    start_time_float=float(start_time_string)\n",
    "    print('execution time = ', ((end - start_time_float)/60))\n",
    "    return [model_path_work3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "6a661a86",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_prediction(model_path:str,model_path_work2:str,model_path_work3:str)->NamedTuple('Outputs',[('model_path',str)]):\n",
    "    \"\"\"Evaluate the trained embedding model on the face test set.\n",
    "\n",
    "    Every top-level file in ``facenet/test/`` is an anchor image and every\n",
    "    sub-folder is a gallery. A gallery named like the anchor (file name with\n",
    "    ``.jpg`` removed) is a genuine pair: a gallery image counts as an error\n",
    "    when its embedding distance to the anchor is >= 1. Any other gallery is an\n",
    "    impostor pair: an image counts as an error when the distance is < 1.\n",
    "\n",
    "    Args:\n",
    "        model_path: weights file produced by worker 1; loaded into the model.\n",
    "        model_path_work2: output of worker 2. Not read here; consuming it as\n",
    "            an input orders this step after worker 2 in the pipeline.\n",
    "        model_path_work3: output of worker 3; unused for the same reason.\n",
    "\n",
    "    Returns:\n",
    "        ``[model_path]``, forwarded to the serving step.\n",
    "    \"\"\"\n",
    "    from os import listdir\n",
    "    from os.path import isfile\n",
    "    import time\n",
    "    import numpy as np\n",
    "    import cv2\n",
    "    from sklearn.manifold import TSNE\n",
    "    from scipy.spatial import distance\n",
    "    import tensorflow as tf\n",
    "    import sys\n",
    "    sys.path.append(\"./\")\n",
    "    sys.path.append(\"/persist-log\")\n",
    "    sys.path.append(\"/facenet/test\")\n",
    "    from img_process import align_image, prewhiten\n",
    "    from triplet_training import create_base_network\n",
    "    from utils import scatter\n",
    "    from config import img_size, channel, classify, FREEZE_LAYERS, facenet_weight_path, faces_data_dir\n",
    "    anchor_input = tf.keras.Input((img_size, img_size, channel,), name='anchor_input')\n",
    "    Shared_DNN = create_base_network((img_size, img_size, channel), FREEZE_LAYERS, facenet_weight_path)\n",
    "    encoded_anchor = Shared_DNN(anchor_input)\n",
    "    \n",
    "    model = tf.keras.Model(inputs=anchor_input, outputs=encoded_anchor)\n",
    "    model.load_weights(model_path)\n",
    "    model.summary()\n",
    "    start = time.time()\n",
    "\n",
    "    def l2_normalize(x, axis=-1, epsilon=1e-10):\n",
    "        # Scale to unit L2 norm; epsilon guards against dividing by zero.\n",
    "        output = x / np.sqrt(np.maximum(np.sum(np.square(x), axis=axis, keepdims=True), epsilon))\n",
    "        return output\n",
    "\n",
    "    # Acquire embedding from image; returns None when align_image returns None.\n",
    "    def embedding_extractor(img_path):\n",
    "        img = cv2.imread(img_path)\n",
    "        aligned = align_image(img)\n",
    "        if aligned is None:\n",
    "            print(img_path + ' is None')\n",
    "            return None\n",
    "        aligned = aligned.reshape(-1, img_size, img_size, channel)\n",
    "        return l2_normalize(np.concatenate(model.predict(aligned)))\n",
    "\n",
    "    testset_dir = 'facenet/test/'\n",
    "    items = listdir(testset_dir)\n",
    "\n",
    "    jpgsList = [x for x in items if isfile(testset_dir + x)]\n",
    "    foldersList = [x for x in items if not isfile(testset_dir + x)]\n",
    "\n",
    "    print(jpgsList)\n",
    "    print(foldersList)\n",
    "\n",
    "    acc_total = 0\n",
    "    pair_count = 0  # (anchor, folder) pairs whose accuracy went into acc_total\n",
    "    for anch_jpg in jpgsList:\n",
    "        anchor_path = testset_dir + anch_jpg\n",
    "        anch_emb = embedding_extractor(anchor_path)\n",
    "        if anch_emb is None:\n",
    "            # No usable anchor embedding: skip it rather than crash in distance.euclidean.\n",
    "            continue\n",
    "        # Name of the gallery folder holding the same person as this anchor.\n",
    "        anchor_name = anch_jpg.replace('.jpg', '')\n",
    "\n",
    "        for clt_folder in foldersList:\n",
    "            clt_path = testset_dir + clt_folder + '/'\n",
    "            clt_jpgs = listdir(clt_path)\n",
    "            same_person = clt_folder == anchor_name\n",
    "            print('==============' + clt_folder + '&' + anch_jpg + '==============')\n",
    "\n",
    "            loss = 0\n",
    "            dist_sum = 0\n",
    "            compared = 0\n",
    "            for clt_jpg in clt_jpgs:\n",
    "                clt_emb = embedding_extractor(clt_path + clt_jpg)\n",
    "                if clt_emb is None:\n",
    "                    continue  # unalignable gallery image: left out of the score\n",
    "                distanceDiff = distance.euclidean(anch_emb, clt_emb)  # calculate the distance\n",
    "                dist_sum += distanceDiff\n",
    "                compared += 1\n",
    "                # Threshold 1: genuine pairs should be closer, impostor pairs farther.\n",
    "                is_error = distanceDiff >= 1 if same_person else distanceDiff < 1\n",
    "                if is_error:\n",
    "                    loss += 1\n",
    "\n",
    "            if compared == 0:\n",
    "                # Empty or fully unalignable folder; nothing to score (and no division by zero).\n",
    "                print('no comparable images in ' + clt_path + ', skipped')\n",
    "                continue\n",
    "\n",
    "            # Mean distance over the images actually compared.\n",
    "            print('sum1' if same_person else 'sum2', dist_sum / compared)\n",
    "            print('loss: ', loss)\n",
    "            accuracy = (compared - loss) / compared\n",
    "            print('accuracy: ', accuracy)\n",
    "            acc_total += accuracy\n",
    "            pair_count += 1\n",
    "\n",
    "            print('--acc_total', acc_total)\n",
    "\n",
    "    # Average over the pairs actually evaluated, so the score does not assume a fixed 9 x 9 test set.\n",
    "    acc_mean = acc_total / pair_count * 100 if pair_count else 0.0\n",
    "    print('final acc++------: ', acc_mean)\n",
    "    end = time.time()\n",
    "    print ('execution time', (end - start))\n",
    "    \n",
    "    return [model_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "87cae256",
   "metadata": {},
   "outputs": [],
   "source": [
    "#serving\n",
    "def serving(model_path:str, log_folder:str):\n",
    "    \"\"\"Run a Flask app that ranks an uploaded face against the test galleries.\n",
    "\n",
    "    ``/upload`` (GET or POST) renders ``upload.html``. On POST, the ``file``\n",
    "    form field must be a ``.jpg``/``.JPG`` image; it is embedded with the\n",
    "    trained model and compared with every image of every sub-folder of\n",
    "    ``facenet/test/``. The three folders with the smallest mean embedding\n",
    "    distance (below 2) and a mean per-folder agreement score are rendered.\n",
    "\n",
    "    Args:\n",
    "        model_path: trained weights file, loaded on the first request.\n",
    "        log_folder: path of the shared volume; not read inside this function.\n",
    "    \"\"\"\n",
    "    from flask import Flask,render_template,url_for,request,redirect,make_response,jsonify\n",
    "    from werkzeug.utils import secure_filename\n",
    "    import os \n",
    "    import cv2\n",
    "    import sys\n",
    "    import time\n",
    "    import base64\n",
    "    import math\n",
    "    import tempfile\n",
    "    from datetime import timedelta\n",
    "    import numpy as np\n",
    "    from os import listdir\n",
    "    from os.path import isfile\n",
    "    from sklearn.manifold import TSNE\n",
    "    from scipy.spatial import distance\n",
    "    import tensorflow as tf\n",
    "    \n",
    "    sys.path.append(\"./\")\n",
    "    sys.path.append(\"/persist-log\")\n",
    "    sys.path.append(\"/templates\")\n",
    "    \n",
    "    from img_process import align_image, prewhiten\n",
    "    from triplet_training import create_base_network\n",
    "    from utils import scatter\n",
    "    from config import img_size, channel, classify, FREEZE_LAYERS, facenet_weight_path, faces_data_dir\n",
    "    ALLOWED_EXTENSIONS = set(['jpg','JPG'])\n",
    "    testset_dir = 'facenet/test/'\n",
    "    # Folders whose mean distance is not below this value are never ranked.\n",
    "    MAX_RANK_DISTANCE = 2\n",
    "    # Holds the model once built, so every request reuses it.\n",
    "    model_cache = {}\n",
    "\n",
    "    def allowed_file(filename):\n",
    "        return '.' in filename and filename.rsplit('.',1)[1] in ALLOWED_EXTENSIONS\n",
    "\n",
    "    def return_img_stream(img_local_path):\n",
    "        # Base64 text of the file, for embedding the image in the HTML page.\n",
    "        with open(img_local_path,'rb') as img_f:\n",
    "            return base64.b64encode(img_f.read()).decode()\n",
    "\n",
    "    # L2 normalization\n",
    "    def l2_normalize(x, axis=-1, epsilon=1e-10):\n",
    "        output = x / np.sqrt(np.maximum(np.sum(np.square(x), axis=axis, keepdims=True), epsilon))\n",
    "        return output\n",
    "\n",
    "    def load_model():\n",
    "        # Build the embedding network and load the weights on first use only,\n",
    "        # instead of rebuilding it for every upload.\n",
    "        if 'model' not in model_cache:\n",
    "            anchor_input = tf.keras.Input((img_size, img_size, channel,), name='anchor_input')\n",
    "            Shared_DNN = create_base_network((img_size, img_size, channel), FREEZE_LAYERS, facenet_weight_path)\n",
    "            encoded_anchor = Shared_DNN(anchor_input)\n",
    "            model = tf.keras.Model(inputs=anchor_input, outputs=encoded_anchor)\n",
    "            model.load_weights(model_path) #/persist-log\n",
    "            model.summary()\n",
    "            model_cache['model'] = model\n",
    "        return model_cache['model']\n",
    "\n",
    "    # Acquire embedding from image; None when align_image returns None.\n",
    "    def embedding_extractor(img_path,model):\n",
    "        img = cv2.imread(img_path)\n",
    "        aligned = align_image(img)\n",
    "        if aligned is not None:\n",
    "            aligned = aligned.reshape(-1, img_size, img_size, channel)\n",
    "\n",
    "            embs = l2_normalize(np.concatenate(model.predict(aligned)))\n",
    "            return embs\n",
    "        else:\n",
    "            print(img_path + ' is None')\n",
    "            return None\n",
    "\n",
    "    def rank_folders(anch_emb, model):\n",
    "        \"\"\"Compare anch_emb with every gallery folder of testset_dir.\n",
    "\n",
    "        Returns (acc_mean, top3): acc_mean is the mean per-folder agreement in\n",
    "        percent, rounded to 2 decimals; top3 holds (folder, mean distance\n",
    "        rounded to 2) for the three closest folders, padded with ('', 0).\n",
    "        \"\"\"\n",
    "        items = listdir(testset_dir)\n",
    "        foldersList = [x for x in items if not isfile(testset_dir + x)]\n",
    "        print(foldersList)\n",
    "\n",
    "        acc_total = 0\n",
    "        evaluated = 0\n",
    "        candidates = []\n",
    "        for clt_folder in foldersList:\n",
    "            clt_path = testset_dir + clt_folder + '/'\n",
    "            print('==============' + clt_folder + '==============')\n",
    "            distances = []\n",
    "            for clt_jpg in listdir(clt_path):\n",
    "                clt_emb = embedding_extractor(clt_path + clt_jpg, model)\n",
    "                if clt_emb is not None:\n",
    "                    distances.append(distance.euclidean(anch_emb, clt_emb))\n",
    "            if not distances:\n",
    "                print('no comparable images in ' + clt_path + ', skipped')\n",
    "                continue\n",
    "\n",
    "            distanceDiffbig = sum(1 for d in distances if d >= 1)\n",
    "            distanceDiffsmall = len(distances) - distanceDiffbig\n",
    "            # loss is the minority count, so the score is the share of images\n",
    "            # that agree with the folder's majority (near or far).\n",
    "            loss = min(distanceDiffbig, distanceDiffsmall)\n",
    "            mean_distance = sum(distances) / len(distances)\n",
    "            print('distance mean is:', mean_distance)\n",
    "            print('distanceDiffsmall = ', distanceDiffsmall)\n",
    "            print('distanceDiffbig = ', distanceDiffbig)\n",
    "            print('loss: ', loss)\n",
    "            acc_total += (len(distances) - loss) / len(distances)\n",
    "            evaluated += 1\n",
    "            if mean_distance < MAX_RANK_DISTANCE:\n",
    "                candidates.append((clt_folder, mean_distance))\n",
    "\n",
    "        # Sorting ranks every folder: a new best pushes earlier ones down\n",
    "        # instead of overwriting them. The sort is stable, so ties keep\n",
    "        # directory-listing order.\n",
    "        candidates.sort(key=lambda c: c[1])\n",
    "        top3 = [(folder, round(d, 2)) for folder, d in candidates[:3]]\n",
    "        top3 += [('', 0)] * (3 - len(top3))\n",
    "        for rank, (folder, d) in enumerate(top3, 1):\n",
    "            print('rank', rank, ':', folder, 'distance is ', d)\n",
    "        acc_mean = round(acc_total / evaluated * 100, 2) if evaluated else 0\n",
    "        print('final acc++------: ', acc_mean)\n",
    "        return acc_mean, top3\n",
    "\n",
    "    app = Flask(__name__, template_folder=\"/templates\")\n",
    "    # Cache sent files for 1 s. Set via the config key: newer Flask releases\n",
    "    # deprecate and then drop the send_file_max_age_default attribute alias.\n",
    "    app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(seconds=1)\n",
    "\n",
    "    @app.route('/upload',methods=['GET','POST'])\n",
    "    def upload():\n",
    "        img_stream = ''\n",
    "        acc_mean = 0\n",
    "        top3 = [('', 0)] * 3\n",
    "\n",
    "        if request.method =='POST':\n",
    "            f = request.files.get('file')\n",
    "            if f is None or not f.filename or not allowed_file(f.filename):\n",
    "                print('upload rejected: a .jpg file is required')\n",
    "            else:\n",
    "                # Save to a private temp file, not into the test gallery directory.\n",
    "                fd, upload_path = tempfile.mkstemp(suffix='.jpg')\n",
    "                os.close(fd)\n",
    "                try:\n",
    "                    f.save(upload_path)\n",
    "                    img_stream = return_img_stream(upload_path)\n",
    "                    model = load_model()\n",
    "                    anch_emb = embedding_extractor(upload_path, model)\n",
    "                    if anch_emb is not None:\n",
    "                        acc_mean, top3 = rank_folders(anch_emb, model)\n",
    "                finally:\n",
    "                    # Remove the upload even if embedding or ranking failed.\n",
    "                    os.remove(upload_path)\n",
    "\n",
    "        (face, distance_low1), (face2, distance_low2), (face3, distance_low3) = top3\n",
    "        return render_template('upload.html',img_stream = img_stream, face = face , face2 = face2 , face3 = face3 , distance_low1 = distance_low1, distance_low2 = distance_low2 , distance_low3 = distance_low3, acc_mean = acc_mean )\n",
    "    \n",
    "    if __name__ == '__main__':\n",
    "        # NOTE(review): debug=True enables the Werkzeug debugger; keep the\n",
    "        # server bound to 127.0.0.1 or turn debug off before exposing it.\n",
    "        app.run(host = '127.0.0.1',port=8987,debug=True)\n",
    "        \n",
    "    return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "f8a12f54",
   "metadata": {},
   "outputs": [],
   "source": [
    "import kfp.dsl as dsl\n",
    "import kfp.components as components\n",
    "from typing import NamedTuple\n",
    "import kfp\n",
    "from kfp import dsl\n",
    "from kfp.components import func_to_container_op, InputPath, OutputPath\n",
    "from kubernetes.client.models import V1ContainerPort\n",
    "@dsl.pipeline(\n",
    "   name='triplet_training pipeline',\n",
    "   description='triplet training test.'\n",
    ")\n",
    "def triplet_training_pipeline():\n",
    "    \"\"\"Kubeflow pipeline: load data, train on three workers, evaluate, serve.\n",
    "\n",
    "    Every step mounts the same ReadWriteMany volume at ``/persist-log``. Each\n",
    "    training worker gets its own pod label and node-selector value and maps\n",
    "    container port 3000 to host port 3000.\n",
    "    \"\"\"\n",
    "    log_folder = '/persist-log'\n",
    "    pvc_name = \"triplet-trainaing-pvc\"\n",
    "    training_image = \"mike0355/k8s-facenet-distributed-training:4\"\n",
    "    serving_image = \"mike0355/k8s-facenet-serving:3\"\n",
    "    container_port = 3000\n",
    "\n",
    "    # Pod label key and per-worker values (selector pod-name: worker-N).\n",
    "    name = \"pod-name\"\n",
    "    value1, value2, value3 = \"worker-1\", \"worker-2\", \"worker-3\"\n",
    "    # Node-selector label key and per-worker values.\n",
    "    label_name = \"disktype\"\n",
    "    label_value1, label_value2, label_value3 = \"worker-1\", \"worker-2\", \"worker-3\"\n",
    "\n",
    "    vop = dsl.VolumeOp(\n",
    "        name=pvc_name,\n",
    "        resource_name=\"newpvc\",\n",
    "        storage_class=\"managed-nfs-storage\",\n",
    "        size=\"30Gi\",\n",
    "        modes=dsl.VOLUME_MODE_RWM\n",
    "    )\n",
    "\n",
    "    # Component factories: each Python function runs in the given container image.\n",
    "    load_data_op = func_to_container_op(func=load_data, base_image=training_image)\n",
    "    distributed_training_worker1_op = func_to_container_op(func=distributed_training_worker1, base_image=training_image)\n",
    "    distributed_training_worker2_op = func_to_container_op(func=distributed_training_worker2, base_image=training_image)\n",
    "    distributed_training_worker3_op = func_to_container_op(func=distributed_training_worker3, base_image=training_image)\n",
    "    model_prediction_op = func_to_container_op(func=model_prediction, base_image=training_image)\n",
    "    serving_op = func_to_container_op(func=serving, base_image=serving_image)\n",
    "\n",
    "    #----------------------------------------------------------task\n",
    "    load_data_task = load_data_op(log_folder).add_pvolumes({log_folder: vop.volume})\n",
    "    start_time = load_data_task.outputs['start_time_string']\n",
    "\n",
    "    # The three workers differ only in component, pod label and node label.\n",
    "    worker_specs = [\n",
    "        (distributed_training_worker1_op, value1, label_value1),\n",
    "        (distributed_training_worker2_op, value2, label_value2),\n",
    "        (distributed_training_worker3_op, value3, label_value3),\n",
    "    ]\n",
    "    worker_tasks = []\n",
    "    for worker_op, pod_value, node_value in worker_specs:\n",
    "        task = (\n",
    "            worker_op(start_time)\n",
    "            .add_pvolumes({log_folder: vop.volume})\n",
    "            .add_pod_label(name, pod_value)\n",
    "            .add_node_selector_constraint(label_name, node_value)\n",
    "            .add_port(V1ContainerPort(container_port=container_port, host_port=container_port))\n",
    "        )\n",
    "        worker_tasks.append(task)\n",
    "    worker1_task, worker2_task, worker3_task = worker_tasks\n",
    "\n",
    "    # Evaluation consumes all three worker outputs, so it runs after every worker.\n",
    "    model_prediction_task = model_prediction_op(\n",
    "        worker1_task.outputs['model_path'],\n",
    "        worker2_task.outputs['model_path_work2'],\n",
    "        worker3_task.outputs['model_path_work3'],\n",
    "    ).add_pvolumes({log_folder: vop.volume})\n",
    "\n",
    "    serving_op(model_prediction_task.outputs['model_path'], log_folder).add_pvolumes({log_folder: vop.volume})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "70db3347",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the pipeline function into a package file for Kubeflow Pipelines.\n",
    "pipeline_package_path = 'distributed-training-1011-final.yaml'\n",
    "kfp.compiler.Compiler().compile(triplet_training_pipeline, pipeline_package_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "071e9b42",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "860d92e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "                \n",
    "                \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5df6fcc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "283c28a6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
